# -*-Python-*-
import mesh_tensorflow.transformer.evolved_transformer
import mesh_tensorflow.transformer.transformer_layers


################################################################################
################## IMPORTANT: READ BEFORE USING! ###############################
################################################################################
# One evolved transformer layer is equivalent to two vanilla transformer layers.
# You'll have to manually halve the number of layers. The value set below
# assumes 12 layers originally. Please override it in your own gin config with
# half of your original num_layers.
################################################################################

# Gin macro holding the layer count. 6 is half of the 12 vanilla layers assumed
# by the notice above; override it when your baseline depth differs.
# NOTE(review): nothing in this file references %num_layers -- presumably it is
# consumed by the base model config this file is combined with; confirm there.
num_layers = 6

# Sublayer types handed to transformer.make_layer_stack within the `encoder`
# scope. Each entry is a reference (@) to a gin configurable; the list order is
# the order passed to make_layer_stack.
# NOTE(review): presumably this group of sublayers is repeated num_layers times
# by make_layer_stack -- confirm against its implementation.
encoder/transformer.make_layer_stack.layers = [
    @evolved_transformer.GatedLinearUnitLayer,
    @evolved_transformer.EncoderConvolutionalLayer,
    @transformer_layers.SelfAttention,
    @transformer_layers.DenseReluDense,
]
# Sublayer types handed to transformer.make_layer_stack within the `decoder`
# scope. Each entry is a reference (@) to a gin configurable; the list order is
# the order passed to make_layer_stack. Unlike the encoder list, this one
# includes an EncDecAttention sublayer.
# NOTE(review): presumably this group of sublayers is repeated num_layers times
# by make_layer_stack -- confirm against its implementation.
decoder/transformer.make_layer_stack.layers = [
    @evolved_transformer.DecoderAttentionLayer,
    @evolved_transformer.DecoderConvolutionalLayer,
    @transformer_layers.SelfAttention,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
]

# Per-layer hyperparameters for the evolved-transformer layers. The values are
# forwarded from the macros %d_model, %dropout_rate and %num_heads, which are
# not defined in this file and must be supplied by the config it is combined
# with.

# Encoder convolutional layer: model width and dropout rate.
evolved_transformer.EncoderConvolutionalLayer.d_model = %d_model
evolved_transformer.EncoderConvolutionalLayer.dropout_rate = %dropout_rate

# Decoder attention layer: head count taken from the shared %num_heads macro.
# NOTE(review): the parameter is named `base_num_heads`, so the layer may derive
# other head counts from it -- confirm in evolved_transformer.py.
evolved_transformer.DecoderAttentionLayer.base_num_heads = %num_heads

# Decoder convolutional layer: model width and dropout rate.
evolved_transformer.DecoderConvolutionalLayer.d_model = %d_model
evolved_transformer.DecoderConvolutionalLayer.dropout_rate = %dropout_rate

# Scoped binding: only DenseReluDense layers built under the `decoder` scope get
# the "swish" activation; this file does not set the encoder's activation.
decoder/transformer_layers.DenseReluDense.activation = "swish"
